{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "\n",
    "def fully_connected_layer(X, W, b, activation=None):\n",
    "    \"\"\"\n",
    "    Implements a Fully Connected layer of Neural Networks\n",
    "    Input:\n",
    "    1) X: n x d tensor - each row of X is an input vector, there are n vectors\n",
    "          each of size d.\n",
    "    2) W: m x d tensor\n",
    "    Returns:\n",
    "    y: m tensor, y = W X.transpose + b\n",
    "    \"\"\"\n",
    "    assert (activation is None or activation is torch.sigmoid or activation is torch.relu\n",
    "            or activation is torch.tanh or activation is torch.heaviside)\n",
    "    assert torch.is_tensor(X) and torch.is_tensor(W) and torch.is_tensor(b)\n",
    "    assert len(X.shape) == 2\n",
    "    n = X.shape[0]  # number of input vectors\n",
    "    d = X.shape[1]  # input  dimensionality\n",
    "    m = b.shape[0]  # output dimensionality\n",
    "    assert b.shape == torch.Size([m]), \"b.shape = {}\".format(b.shape)\n",
    "    assert W.shape == torch.Size([m, d]), \"W.shape = {}\".format(W.shape)\n",
    "\n",
    "    X = torch.cat((X, torch.ones([X.shape[0], 1], dtype=torch.float32)), dim=1)\n",
    "    W = torch.cat((W, b.unsqueeze(dim=1)), dim=1)\n",
    "    y = torch.matmul(W, X.transpose(0, 1))\n",
    "    if activation is not None:\n",
    "        if activation is torch.heaviside:\n",
    "            y = activation(y, torch.tensor(1.0))\n",
    "        else:\n",
    "            y = activation(y)\n",
    "\n",
    "    return y.transpose(0, 1)\n",
    "\n",
    "\n",
    "def Perceptron(X, W, b, activation=torch.heaviside):\n",
    "    assert W.shape[0] == 1 and b.shape[0] == 1\n",
    "    return fully_connected_layer(X, W, b, activation=activation)\n",
    "\n",
    "\n",
    "def MLP(X, W0, W1, b0, b1, activation0=torch.heaviside, activation1=None):\n",
    "    y0 = fully_connected_layer(X=X, W=W0, b=b0, activation=activation0)\n",
    "    return fully_connected_layer(X=y0, W=W1, b=b1, activation=activation1)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
